In [ ]:
import pathlib
from pathlib import Path

import numpy as np
import pandas as pd

# plotting modules
from matplotlib import pyplot as plt
import seaborn as sns

import plotly
import plotly.graph_objects as go
plotly.offline.init_notebook_mode()

from PIL import Image

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report, precision_recall_curve,
    ConfusionMatrixDisplay,
)

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Dense
from tensorflow.keras.utils import image_dataset_from_directory, to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint

Framing the Problem¶

The goal of this notebook is to implement one of the features of our proposed product that can be solved by machine learning.

This notebook is for Plant Disease Identification.

We will use a convolutional neural network to predict which disease, if any, is present on a plant.

Getting the data¶

The dataset we will train our models on comes from Kaggle - https://www.kaggle.com/datasets/sadmansakibmahi/plant-disease-expert/data

In [ ]:
data_folder = pathlib.Path("../../../../../../Downloads/datasets/plant-disease")
In [ ]:
def create_image_dataframe(folder):
    data = {'ImagePath': [], 'ClassLabel': []}
    for class_folder in folder.iterdir():
        if class_folder.is_dir():
            for img_path in class_folder.iterdir():
                data['ImagePath'].append(img_path)
                data['ClassLabel'].append(class_folder.name)
    return pd.DataFrame(data)
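To sanity-check what this traversal collects, here is a stdlib-only sketch on a throwaway directory; the class names 'healthy' and 'blight' are made up for illustration, and the dict of lists stands in for the DataFrame columns:

```python
import tempfile
from pathlib import Path

data = {'ImagePath': [], 'ClassLabel': []}
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # One subfolder per class, one image file each (hypothetical names)
    for cls in ('healthy', 'blight'):
        (root / cls).mkdir()
        (root / cls / 'img0.jpg').touch()
    # Same traversal as create_image_dataframe, sorted for a stable order
    for class_folder in sorted(root.iterdir()):
        if class_folder.is_dir():
            for img_path in class_folder.iterdir():
                data['ImagePath'].append(str(img_path))
                data['ClassLabel'].append(class_folder.name)

print(data['ClassLabel'])  # ['blight', 'healthy']
```

Each image ends up labelled with the name of the folder it was found in, which is exactly the structure the Kaggle download uses.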
In [ ]:
df = create_image_dataframe(data_folder)
In [ ]:
df
Out[ ]:
ImagePath ClassLabel
0 ..\..\..\..\..\..\Downloads\datasets\plant-dis... algal leaf in tea
1 ..\..\..\..\..\..\Downloads\datasets\plant-dis... algal leaf in tea
2 ..\..\..\..\..\..\Downloads\datasets\plant-dis... algal leaf in tea
3 ..\..\..\..\..\..\Downloads\datasets\plant-dis... algal leaf in tea
4 ..\..\..\..\..\..\Downloads\datasets\plant-dis... algal leaf in tea
... ... ...
128077 ..\..\..\..\..\..\Downloads\datasets\plant-dis... Tomato Tomato mosaic virus
128078 ..\..\..\..\..\..\Downloads\datasets\plant-dis... Tomato Tomato mosaic virus
128079 ..\..\..\..\..\..\Downloads\datasets\plant-dis... Tomato Tomato mosaic virus
128080 ..\..\..\..\..\..\Downloads\datasets\plant-dis... Tomato Tomato mosaic virus
128081 ..\..\..\..\..\..\Downloads\datasets\plant-dis... Tomato Tomato mosaic virus

128082 rows × 2 columns

As the DataFrame above shows, we have 128,082 images across all the folders combined.

The class labels represent each of the folders we have in our downloaded folder.

In [ ]:
len(df['ClassLabel'].unique())
Out[ ]:
43

43 classes, that is, 43 different kinds of 'states' of some set of plants. 'States' rather than diseases, because some of the plants in the dataset are healthy.

In [ ]:
df['ClassLabel'].value_counts()
Out[ ]:
ClassLabel
Grape Esca Black Measles                       13284
Grape Black rot                                11328
Grape Leaf blight Isariopsis Leaf Spot         10332
Tomato Bacterial spot                           6381
Apple Apple scab                                6048
Apple Black rot                                 5964
Tomato Late blight                              5727
Tomato Septoria leaf spot                       5313
Tomato Spider mites Two spotted spider mite     5028
Tomato Target Spot                              4212
Apple healthy                                   3948
Common Rust in corn Leaf                        3918
Tomato healthy                                  3819
Pepper bell healthy                             3549
Blight in corn Leaf                             3438
Potato Early blight                             3000
Potato Late blight                              3000
Tomato Early blight                             3000
Pepper bell Bacterial spot                      2988
Tomato Leaf Mold                                2856
Corn (maize) healthy                            2790
Strawberry Leaf scorch                          2664
Apple Cedar apple rust                          2640
Cherry (including sour) Powdery mildew          2526
Cherry (including_sour) healthy                 2052
Gray Leaf Spot in corn Leaf                     1722
Tomato Tomato mosaic virus                      1119
Strawberry healthy                              1095
Grape healthy                                   1017
red leaf spot in tea                             429
Potato healthy                                   366
algal leaf in tea                                339
brown blight in tea                              339
corn crop                                        312
anthracnose in tea                               300
bird eye spot in tea                             300
healthy tea leaf                                 222
potato hollow heart                              180
potato crop                                      120
Leaf smut in rice leaf                           120
Brown spot in rice leaf                          120
Bacterial leaf blight in rice leaf               120
tomato canker                                     57
Name: count, dtype: int64
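The counts above are highly skewed. One way to quantify this is the imbalance ratio between the largest and smallest classes; a quick sketch using three counts copied from the table:

```python
# Class counts taken from the value_counts() output above
counts = {
    'Grape Esca Black Measles': 13284,
    'Potato Early blight': 3000,
    'tomato canker': 57,
}

# Ratio of the most frequent class to the least frequent one
imbalance_ratio = max(counts.values()) / min(counts.values())
print(round(imbalance_ratio, 1))  # 233.1
```

An imbalance of over 200:1 is worth keeping in mind when we read per-class metrics later: accuracy alone will be dominated by the large classes.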

Exploratory Data Analysis¶

Let us take a look at the images in the dataset.

In [ ]:
import matplotlib.pyplot as plt
from PIL import Image

grouped = df.groupby('ClassLabel')

# Plot the first 5 images for each class
for label, group in grouped:
    print(label)
    print('=====================')
    
    # Create a figure with 5 subplots (1 row, 5 columns)
    fig, axes = plt.subplots(1, 5, figsize=(20, 4))
    
    # Iterate over the first 5 rows of the group
    for i in range(min(5, len(group))):  # This ensures we don't go out of bounds if there are less than 5 images
        img_path = group['ImagePath'].iloc[i]
        img = Image.open(img_path)
        
        # Plot the image in the i-th subplot
        axes[i].imshow(img)
        axes[i].set_title(label)
        axes[i].axis('off')
    
    # Hide any unused subplots if the group has less than 5 images
    if len(group) < 5:
        for j in range(len(group), 5):
            axes[j].axis('off')
    
    plt.show()
[Output: for each of the 43 class labels, from "Apple Apple scab" through "tomato canker", the class name is printed followed by a row of up to five sample images (images not preserved in this export).]

Modeling¶

Before we train our model, we need to prepare our images into the format the model expects and split the data into training, validation, and test sets.

In [ ]:
df['ImagePath'] = df['ImagePath'].astype(str)
grouped = df.groupby('ClassLabel')

train_df = pd.DataFrame()
val_df = pd.DataFrame()
test_df = pd.DataFrame()

for _, group in grouped:
    # Split into train and temporary test
    train_tmp, test_tmp = train_test_split(group, test_size=0.3, random_state=42)  # 70% train, 30% temp test
    
    # Split the temporary test into actual validation and test
    val_tmp, test_final = train_test_split(test_tmp, test_size=0.5, random_state=42)  # Split 30% into 15% val, 15% test

    # Append to the respective DataFrames
    train_df = pd.concat([train_df, train_tmp])
    val_df = pd.concat([val_df, val_tmp])
    test_df = pd.concat([test_df, test_final])

# Now, you have train_df, val_df, and test_df
print(f"Training Set: {train_df.shape[0]} samples")
print(f"Validation Set: {val_df.shape[0]} samples")
print(f"Test Set: {test_df.shape[0]} samples")
Training Set: 89642 samples
Validation Set: 19213 samples
Test Set: 19227 samples
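The per-class loop above is effectively a stratified 70/15/15 split: every class contributes the same proportions to each set (scikit-learn's `train_test_split` also accepts a `stratify=` argument for the same effect in one call). A stdlib-only sketch of the idea, on made-up labels:

```python
import random

# Toy labels standing in for df['ClassLabel'] (hypothetical data)
labels = ['apple scab'] * 100 + ['tomato canker'] * 20

def stratified_split(items, seed=42):
    """70/15/15 split applied within each class, mirroring the loop above."""
    by_class = {}
    for item in items:
        by_class.setdefault(item, []).append(item)
    train, val, test = [], [], []
    rng = random.Random(seed)
    for group in by_class.values():
        rng.shuffle(group)
        n = len(group)
        n_train = int(n * 0.7)
        n_val = int(n * 0.15)
        train += group[:n_train]
        val += group[n_train:n_train + n_val]
        test += group[n_train + n_val:]
    return train, val, test

train, val, test = stratified_split(labels)
print(len(train), len(val), len(test))  # 84 18 18
```

Both classes appear in every split, which a single unstratified shuffle would not guarantee for the rarest classes (tomato canker has only 57 images).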

Here, we are applying data augmentation to the training set to help improve the generalization of the model.

We randomly rotate the images, shift them a little along the width and height, zoom, and flip them both horizontally and vertically.

Another very important step you can see below is that we rescale the RGB values of the images from the 0-255 range to values between 0 and 1. This keeps all input features on a comparable scale, so the model does not implicitly assign more weight to pixels with very high raw values.
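The rescaling itself is just a division by 255, which is what the `rescale=1./255` argument does pixel-wise:

```python
# Rescaling maps 8-bit pixel values into [0, 1]
pixels = [0, 64, 128, 255]
scaled = [p / 255 for p in pixels]
print([round(s, 3) for s in scaled])  # [0.0, 0.251, 0.502, 1.0]
```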

In [ ]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def add_noise(image):
    noise = tf.random.normal(shape=tf.shape(image), mean=0.0, stddev=0.05, dtype=tf.float32)
    return image + noise

def color_jitter(image):
    image = tf.image.random_hue(image, 0.08)
    image = tf.image.random_saturation(image, 0.6, 1.6)
    image = tf.image.random_brightness(image, 0.05)
    image = tf.image.random_contrast(image, 0.7, 1.3)
    return image

def resize_and_crop(image):
    # Resize the image to a larger size to ensure we can crop properly
    image = tf.image.resize(image, [180, 180])
    # Randomly crop the image to the target size
    image = tf.image.random_crop(image, size=[150, 150, 3])
    return image

def preprocessing_function(image):
    image = resize_and_crop(image)
    image = add_noise(image)
    image = color_jitter(image)
    return image

train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=30,  # Randomly rotate images in the range (degrees, 0 to 180)
    width_shift_range=0.2,  # Randomly shift images horizontally (fraction of total width)
    height_shift_range=0.2,  # Randomly shift images vertically (fraction of total height)
    zoom_range=0.2,  # Randomly zoom images
    brightness_range=[0.8, 1.2],  # Randomly change brightness
    horizontal_flip=True,  # Randomly flip images horizontally
    vertical_flip=True,  # Randomly flip images vertically
    preprocessing_function=preprocessing_function  # Custom preprocessing
)

val_datagen = ImageDataGenerator(rescale=1./255)
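The `tf.image` calls above operate on whole tensors; the noise step in particular reduces to a simple idea that can be sketched with the stdlib (this is a hypothetical scalar helper for illustration, not the function used in training):

```python
import random

# Sketch of the additive Gaussian noise in add_noise():
# each pixel gets an independent N(0, 0.05) perturbation
def add_noise_scalar(pixels, stddev=0.05, seed=42):
    rng = random.Random(seed)
    return [p + rng.gauss(0.0, stddev) for p in pixels]

clean = [0.2, 0.5, 0.8]
noisy = add_noise_scalar(clean)
# Perturbations are small relative to the [0, 1] pixel range
print(all(abs(n - c) < 0.3 for n, c in zip(noisy, clean)))  # True
```

Small perturbations like this force the network to tolerate sensor noise instead of memorizing exact pixel values.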
In [ ]:
train_generator = train_datagen.flow_from_dataframe(
    dataframe=train_df,
    x_col="ImagePath",
    y_col="ClassLabel",
    batch_size=32,
    seed=42,
    shuffle=True,
    class_mode="categorical",
    target_size=(150,150))

valid_generator = val_datagen.flow_from_dataframe(
    dataframe=val_df,
    x_col="ImagePath",
    y_col="ClassLabel",
    batch_size=32,
    seed=42,
    shuffle=True,
    class_mode="categorical",
    target_size=(150,150))

test_generator = val_datagen.flow_from_dataframe(
    dataframe=test_df,
    x_col="ImagePath",
    y_col="ClassLabel",
    batch_size=32,
    seed=42,
    shuffle=False,
    class_mode="categorical",
    target_size=(150,150))
Found 89642 validated image filenames belonging to 43 classes.
Found 19213 validated image filenames belonging to 43 classes.
Found 19227 validated image filenames belonging to 43 classes.

Convolutional Neural Network Model¶

With TensorFlow and Keras, we will build a convolutional neural network for our model.

In [ ]:
from tensorflow import keras
from tensorflow.keras import layers

inputs = keras.Input(shape=(150, 150, 3))
x = layers.Conv2D(filters=32, kernel_size=3, activation="relu")(inputs)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=64, kernel_size=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=128, kernel_size=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=256, kernel_size=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=256, kernel_size=3, activation="relu")(x)
x = layers.Flatten()(x)
outputs = layers.Dense(43, activation="softmax")(x)  # Adjusted for 43 classes with softmax activation
model = keras.Model(inputs=inputs, outputs=outputs)
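The spatial dimensions reported by `model.summary()` below follow from simple arithmetic: a 3x3 convolution with 'valid' padding shrinks each side by 2, and each 2x2 max-pooling layer halves it with floor division. A quick check:

```python
# Trace the spatial size through the stack
size = 150
for _ in range(4):          # four conv + pool blocks
    size = size - 2         # Conv2D(kernel_size=3), 'valid' padding
    size = size // 2        # MaxPooling2D(pool_size=2)
size = size - 2             # final Conv2D before Flatten
flat = size * size * 256    # 256 channels in the last conv layer
print(size, flat)  # 5 6400
```

The 6400-element flattened vector feeds directly into the 43-unit softmax layer.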

Model Explainability¶

Model Architecture and Training Overview¶

We've constructed a neural network model using TensorFlow and Keras, designed specifically for image classification. This model processes images sized 150x150 pixels with 3 color channels (RGB).

Model Layers:

The model starts with a series of convolutional layers, where each layer is designed to detect specific features in the image like edges, textures, or more complex patterns as we go deeper.

Each convolutional layer is followed by a max-pooling layer, which reduces the spatial dimensions of the image representations, helping to make the detection of features invariant to scale and orientation changes.

After multiple layers of convolutions and pooling, the high-level understanding of the images is flattened into a vector that serves as input to a fully connected layer.

Output Layer:

The final layer is a dense layer with 43 units, corresponding to the number of categories we want to classify. It uses the softmax activation function to output probabilities for each class, indicating the likelihood of the image belonging to each class.

Model Compilation:

The model uses the Adam optimizer for adjusting weights, which is effective and efficient for this kind of problem. It minimizes a function called categorical crossentropy, a common choice for classification tasks, which measures the difference between the predicted probabilities and the actual distribution of the labels. We track the accuracy during training as a straightforward metric to understand how well the model performs.
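For a single one-hot example, categorical crossentropy reduces to the negative log of the probability the model assigned to the true class:

```python
import math

# One-hot target and a softmax output (illustrative numbers)
y_true = [0, 1, 0]          # true class is index 1
y_pred = [0.1, 0.7, 0.2]    # predicted probabilities

# Only the true-class term survives the sum: -log(0.7)
loss = -sum(t * math.log(p) for t, p in zip(y_true, y_pred))
print(round(loss, 4))  # 0.3567
```

The loss is 0 only when the model puts probability 1 on the correct class, and grows without bound as that probability approaches 0.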

Training Process Enhancements¶

To optimize training and avoid common pitfalls:

ModelCheckpoint: Saves the best model as we train, ensuring that we always have the version of the model that performed best on the validation set, in case later iterations perform worse.

EarlyStopping: Monitors the model's performance on a validation set and stops training if the model's performance doesn't improve for 10 consecutive epochs. This prevents overfitting and unnecessary computation by stopping when the model isn't learning anymore.

ReduceLROnPlateau: Reduces the learning rate if the validation loss stops improving. Smaller steps in weight updates can lead to better fine-tuning and better overall model performance.
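The core of the ReduceLROnPlateau behaviour can be sketched in a few lines; this is a simplified stand-in (the real callback also supports cooldown and min_delta), using the same factor=0.2, patience=5, and min_lr=1e-5 we pass below:

```python
# Simplified plateau scheduler: cut the learning rate when val_loss
# has not improved for `patience` consecutive epochs
def schedule(val_losses, lr=1e-3, factor=0.2, patience=5, min_lr=1e-5):
    best, wait = float('inf'), 0
    for loss in val_losses:
        if loss < best:
            best, wait = loss, 0   # improvement: reset the counter
        else:
            wait += 1
            if wait >= patience:   # plateau: reduce lr, floored at min_lr
                lr = max(lr * factor, min_lr)
                wait = 0
    return lr

# Five epochs without improvement after the first triggers one reduction,
# taking the learning rate from 1e-3 toward 2e-4
print(schedule([0.50, 0.60, 0.61, 0.62, 0.63, 0.64]))
```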

Execution¶

The model is trained on the specified training and validation datasets for 10 epochs, but training can stop early if no improvement is seen, as monitored by our callbacks. This setup ensures that the training is both efficient and effective, adapting to the data as needed.

In [ ]:
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 150, 150, 3)]     0         
                                                                 
 conv2d (Conv2D)             (None, 148, 148, 32)      896       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 74, 74, 32)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 72, 72, 64)        18496     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 36, 36, 64)       0         
 2D)                                                             
                                                                 
 conv2d_2 (Conv2D)           (None, 34, 34, 128)       73856     
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 17, 17, 128)      0         
 2D)                                                             
                                                                 
 conv2d_3 (Conv2D)           (None, 15, 15, 256)       295168    
                                                                 
 max_pooling2d_3 (MaxPooling  (None, 7, 7, 256)        0         
 2D)                                                             
                                                                 
 conv2d_4 (Conv2D)           (None, 5, 5, 256)         590080    
                                                                 
 flatten (Flatten)           (None, 6400)              0         
                                                                 
 dense (Dense)               (None, 43)                275243    
                                                                 
=================================================================
Total params: 1,253,739
Trainable params: 1,253,739
Non-trainable params: 0
_________________________________________________________________
In [ ]:
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
In [ ]:
# Define callbacks
from tensorflow.keras import optimizers, callbacks

my_callbacks = [
    callbacks.ModelCheckpoint(filepath='./models/new/disease-vanilla.h5', save_best_only=True),
    callbacks.EarlyStopping(monitor='val_loss', patience=10, verbose=1),
    callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, verbose=1, min_lr=1e-5)
]
In [ ]:
history = model.fit(
    train_generator,
    validation_data=valid_generator,
    epochs=10,
    callbacks=my_callbacks
)
Epoch 1/10
2802/2802 [==============================] - 1272s 451ms/step - loss: 2.0372 - accuracy: 0.4080 - val_loss: 1.7793 - val_accuracy: 0.5659 - lr: 0.0010
Epoch 2/10
2802/2802 [==============================] - 1230s 439ms/step - loss: 0.8408 - accuracy: 0.7311 - val_loss: 0.8972 - val_accuracy: 0.7278 - lr: 0.0010
Epoch 3/10
2802/2802 [==============================] - 1236s 441ms/step - loss: 0.5907 - accuracy: 0.8102 - val_loss: 0.4187 - val_accuracy: 0.8583 - lr: 0.0010
Epoch 4/10
2802/2802 [==============================] - 1002s 358ms/step - loss: 0.4893 - accuracy: 0.8420 - val_loss: 0.3340 - val_accuracy: 0.8904 - lr: 0.0010
Epoch 5/10
2802/2802 [==============================] - 1016s 363ms/step - loss: 0.4382 - accuracy: 0.8590 - val_loss: 0.2867 - val_accuracy: 0.9026 - lr: 0.0010
Epoch 6/10
2802/2802 [==============================] - 1013s 361ms/step - loss: 0.4025 - accuracy: 0.8695 - val_loss: 0.3748 - val_accuracy: 0.8821 - lr: 0.0010
Epoch 7/10
2802/2802 [==============================] - 1035s 370ms/step - loss: 0.3860 - accuracy: 0.8745 - val_loss: 0.3026 - val_accuracy: 0.9004 - lr: 0.0010
Epoch 8/10
2802/2802 [==============================] - 1055s 377ms/step - loss: 0.3655 - accuracy: 0.8819 - val_loss: 0.3547 - val_accuracy: 0.8826 - lr: 0.0010
Epoch 9/10
2802/2802 [==============================] - 1060s 378ms/step - loss: 0.3492 - accuracy: 0.8866 - val_loss: 0.3938 - val_accuracy: 0.8757 - lr: 0.0010
Epoch 10/10
2802/2802 [==============================] - 1076s 384ms/step - loss: 0.3416 - accuracy: 0.8890 - val_loss: 0.2176 - val_accuracy: 0.9278 - lr: 0.0010
In [ ]:
history_df = pd.DataFrame(history.history)
history_df.insert(0, 'epoch', range(1, len(history_df) + 1))
history_df
Out[ ]:
epoch loss accuracy val_loss val_accuracy lr
0 1 2.037185 0.407967 1.779317 0.565867 0.001
1 2 0.840819 0.731086 0.897182 0.727788 0.001
2 3 0.590703 0.810223 0.418683 0.858325 0.001
3 4 0.489264 0.842016 0.334048 0.890439 0.001
4 5 0.438172 0.859006 0.286736 0.902618 0.001
5 6 0.402475 0.869503 0.374806 0.882059 0.001
6 7 0.386044 0.874523 0.302613 0.900380 0.001
7 8 0.365544 0.881886 0.354745 0.882580 0.001
8 9 0.349211 0.886627 0.393828 0.875709 0.001
9 10 0.341619 0.888992 0.217575 0.927809 0.001
In [ ]:
# Create a DataFrame from the history object
history_df = pd.DataFrame(history.history)

# Plot the training and validation loss
plt.figure(figsize=(9, 5))
values = history_df['accuracy']
epochs = range(1, len(values) + 1)
plt.plot(epochs, history_df['loss'], 'bo', label='Training loss')
plt.plot(epochs, history_df['val_loss'], 'ro', label='Validation loss')

plt.xlabel('Epochs')
plt.xticks(epochs)
plt.ylabel('Loss')
plt.legend()
plt.title('Training and validation loss')
plt.show()

# Plot the training and validation accuracy
plt.figure(figsize=(9, 5))
plt.plot(epochs, history_df['accuracy'], 'bo', label='Training accuracy')
plt.plot(epochs, history_df['val_accuracy'], 'ro', label='Validation accuracy')

plt.xlabel('Epochs')
plt.xticks(epochs)
plt.ylabel('Accuracy')
plt.legend()
plt.title('Training and validation accuracy')
plt.show()
[Output: dot plots of training and validation loss, and of training and validation accuracy, over the 10 epochs (images not preserved in this export).]

Interpreting the Model Performance¶

After training the model for 10 epochs, we reached an accuracy of about 92.8% on our validation dataset.

Evaluating the Model on the Test Data¶

In [ ]:
# evaluate the model
model.evaluate(test_generator)
601/601 [==============================] - 37s 61ms/step - loss: 0.2105 - accuracy: 0.9296
Out[ ]:
[0.21050453186035156, 0.9295781850814819]
In [ ]:
# predict the model
y_pred = model.predict(test_generator)

# get the class with the highest probability
y_pred_labels = np.argmax(y_pred, axis=1)

# get the true class
y_true = np.array(test_generator.classes)


# get the class labels
class_labels = list(test_generator.class_indices.keys())

display(classification_report(y_true, y_pred_labels, target_names=class_labels, zero_division=0))
# get the accuracy
accuracy = accuracy_score(y_true, y_pred_labels)
print(f'Accuracy: {accuracy}')
601/601 [==============================] - 27s 44ms/step
[Output: the raw classification-report string, identical to the report printed in readable form by the next cell.]
Accuracy: 0.929578197326676
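Because `test_generator` was created with `shuffle=False`, `test_generator.classes` stays aligned with the prediction order. What `np.argmax` and `accuracy_score` compute here can be written out with the stdlib on a toy batch (the probabilities below are made up):

```python
# Toy batch of softmax outputs, one row per sample
probs = [
    [0.1, 0.7, 0.2],   # predicted class 1
    [0.8, 0.1, 0.1],   # predicted class 0
    [0.3, 0.3, 0.4],   # predicted class 2
]
y_true = [1, 0, 1]

# argmax over each row: index of the highest probability
y_pred = [max(range(len(p)), key=p.__getitem__) for p in probs]
# accuracy: fraction of samples where prediction matches the label
accuracy = sum(p == t for p, t in zip(y_pred, y_true)) / len(y_true)
print(y_pred, accuracy)  # [1, 0, 2] 0.6666666666666666
```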
In [ ]:
print(classification_report(y_true, y_pred_labels, target_names=class_labels, zero_division=0))
                                             precision    recall  f1-score   support

                           Apple Apple scab       0.99      0.96      0.97       908
                            Apple Black rot       0.99      0.99      0.99       895
                     Apple Cedar apple rust       0.99      0.98      0.99       396
                              Apple healthy       0.96      0.96      0.96       593
         Bacterial leaf blight in rice leaf       1.00      1.00      1.00        18
                        Blight in corn Leaf       0.90      0.80      0.85       516
                    Brown spot in rice leaf       0.57      0.67      0.62        18
     Cherry (including sour) Powdery mildew       0.98      0.97      0.98       379
            Cherry (including_sour) healthy       0.99      0.94      0.97       308
                   Common Rust in corn Leaf       0.99      0.91      0.95       588
                       Corn (maize) healthy       1.00      0.99      0.99       419
                            Grape Black rot       0.95      0.99      0.97      1700
                   Grape Esca Black Measles       0.98      1.00      0.99      1993
     Grape Leaf blight Isariopsis Leaf Spot       0.99      0.96      0.98      1550
                              Grape healthy       0.99      0.97      0.98       153
                Gray Leaf Spot in corn Leaf       0.65      0.90      0.75       259
                     Leaf smut in rice leaf       0.88      0.39      0.54        18
                 Pepper bell Bacterial spot       0.95      0.92      0.93       449
                        Pepper bell healthy       0.93      0.95      0.94       533
                        Potato Early blight       0.95      0.97      0.96       450
                         Potato Late blight       0.93      0.88      0.91       450
                             Potato healthy       0.91      0.55      0.68        55
                     Strawberry Leaf scorch       0.96      0.96      0.96       400
                         Strawberry healthy       0.99      0.84      0.90       165
                      Tomato Bacterial spot       0.95      0.91      0.93       958
                        Tomato Early blight       0.83      0.88      0.85       450
                         Tomato Late blight       0.95      0.81      0.88       860
                           Tomato Leaf Mold       0.94      0.84      0.89       429
                  Tomato Septoria leaf spot       0.91      0.88      0.89       797
Tomato Spider mites Two spotted spider mite       0.84      0.91      0.87       755
                         Tomato Target Spot       0.75      0.91      0.82       632
                 Tomato Tomato mosaic virus       0.65      0.99      0.79       168
                             Tomato healthy       0.91      0.99      0.95       573
                          algal leaf in tea       0.39      1.00      0.56        51
                         anthracnose in tea       0.72      0.51      0.60        45
                       bird eye spot in tea       0.61      0.82      0.70        45
                        brown blight in tea       1.00      0.16      0.27        51
                                  corn crop       0.80      0.70      0.75        47
                           healthy tea leaf       1.00      1.00      1.00        34
                                potato crop       0.79      0.61      0.69        18
                        potato hollow heart       0.96      0.89      0.92        27
                       red leaf spot in tea       1.00      0.46      0.63        65
                              tomato canker       0.00      0.00      0.00         9

                                   accuracy                           0.93     19227
                                  macro avg       0.87      0.83      0.83     19227
                               weighted avg       0.94      0.93      0.93     19227

Above, we can see how the model performed: an accuracy of about 93% on the test data, along with per-class metrics.

As expected, performance is better on the classes with more data; the dataset is somewhat imbalanced.

But it is an impressive performance nonetheless.

In [ ]:
# display the classification report in a dataframe
report = classification_report(y_true, y_pred_labels, target_names=class_labels, zero_division=0, output_dict=True)
report_df = pd.DataFrame(report).transpose()
report_df
Out[ ]:
precision recall f1-score support
Apple Apple scab 0.988649 0.959251 0.973728 908.000000
Apple Black rot 0.988839 0.989944 0.989391 895.000000
Apple Cedar apple rust 0.994898 0.984848 0.989848 396.000000
Apple healthy 0.964407 0.959528 0.961961 593.000000
Bacterial leaf blight in rice leaf 1.000000 1.000000 1.000000 18.000000
Blight in corn Leaf 0.903720 0.800388 0.848921 516.000000
Brown spot in rice leaf 0.571429 0.666667 0.615385 18.000000
Cherry (including sour) Powdery mildew 0.983914 0.968338 0.976064 379.000000
Cherry (including_sour) healthy 0.989761 0.941558 0.965058 308.000000
Common Rust in corn Leaf 0.985321 0.913265 0.947926 588.000000
Corn (maize) healthy 0.995181 0.985680 0.990408 419.000000
Grape Black rot 0.953488 0.988824 0.970835 1700.000000
Grape Esca Black Measles 0.982178 0.995484 0.988786 1993.000000
Grape Leaf blight Isariopsis Leaf Spot 0.993968 0.956774 0.975016 1550.000000
Grape healthy 0.993333 0.973856 0.983498 153.000000
Gray Leaf Spot in corn Leaf 0.651685 0.895753 0.754472 259.000000
Leaf smut in rice leaf 0.875000 0.388889 0.538462 18.000000
Pepper bell Bacterial spot 0.949309 0.917595 0.933182 449.000000
Pepper bell healthy 0.926740 0.949343 0.937905 533.000000
Potato Early blight 0.953947 0.966667 0.960265 450.000000
Potato Late blight 0.934272 0.884444 0.908676 450.000000
Potato healthy 0.909091 0.545455 0.681818 55.000000
Strawberry Leaf scorch 0.957816 0.965000 0.961395 400.000000
Strawberry healthy 0.985714 0.836364 0.904918 165.000000
Tomato Bacterial spot 0.947712 0.908142 0.927505 958.000000
Tomato Early blight 0.825996 0.875556 0.850054 450.000000
Tomato Late blight 0.952251 0.811628 0.876334 860.000000
Tomato Leaf Mold 0.940104 0.841492 0.888069 429.000000
Tomato Septoria leaf spot 0.907672 0.875784 0.891443 797.000000
Tomato Spider mites Two spotted spider mite 0.838002 0.911258 0.873096 755.000000
Tomato Target Spot 0.752632 0.905063 0.821839 632.000000
Tomato Tomato mosaic virus 0.649805 0.994048 0.785882 168.000000
Tomato healthy 0.908654 0.989529 0.947368 573.000000
algal leaf in tea 0.389313 1.000000 0.560440 51.000000
anthracnose in tea 0.718750 0.511111 0.597403 45.000000
bird eye spot in tea 0.606557 0.822222 0.698113 45.000000
brown blight in tea 1.000000 0.156863 0.271186 51.000000
corn crop 0.804878 0.702128 0.750000 47.000000
healthy tea leaf 1.000000 1.000000 1.000000 34.000000
potato crop 0.785714 0.611111 0.687500 18.000000
potato hollow heart 0.960000 0.888889 0.923077 27.000000
red leaf spot in tea 1.000000 0.461538 0.631579 65.000000
tomato canker 0.000000 0.000000 0.000000 9.000000
accuracy 0.929578 0.929578 0.929578 0.929578
macro avg 0.870249 0.830239 0.831135 19227.000000
weighted avg 0.936670 0.929578 0.929729 19227.000000
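A table like the one above is easiest to act on when sorted by f1-score, since the weakest classes surface at the top. A minimal sketch, where `report` is a small hypothetical subset shaped like the output of sklearn's `classification_report(..., output_dict=True)`:

```python
import pandas as pd

# Hypothetical subset of the per-class report above, in the dict shape
# produced by classification_report(..., output_dict=True)
report = {
    "Leaf smut in rice leaf": {"precision": 0.875, "recall": 0.389, "f1-score": 0.538, "support": 18},
    "tomato canker": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 9},
    "Grape healthy": {"precision": 0.993, "recall": 0.974, "f1-score": 0.983, "support": 153},
}

report_df = pd.DataFrame(report).transpose()

# Sort ascending so the hardest classes come first
weakest = report_df.sort_values("f1-score").head(2)
print(weakest.index.tolist())  # → ['tomato canker', 'Leaf smut in rice leaf']
```

On the full `report_df` from the cell above, the same two lines would immediately highlight the small tea, rice, and canker classes as the ones that need more data.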
In [ ]:
model.save('./models/plant-disease-vanilla.h5')
In [ ]:
from tensorflow.keras import callbacks
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Flatten, GlobalAveragePooling2D

vgg_model = VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3))
In [ ]:
# Freeze the base model
vgg_model.trainable = False

# Add custom layers on top of the base model
x = vgg_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(512, activation='relu')(x)
predictions = Dense(43, activation='softmax')(x)

# Create the model
model = Model(inputs=vgg_model.input, outputs=predictions)

# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

vgg_callbacks = [
    callbacks.ModelCheckpoint(filepath='./models/new/disease-vgg16.h5', save_best_only=True),
    callbacks.EarlyStopping(monitor='val_loss', patience=10, verbose=1),
    callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, verbose=1, min_lr=1e-5)
]

# Training the model
history = model.fit(train_generator, validation_data=valid_generator, epochs=10, callbacks=vgg_callbacks)
Epoch 1/10
2802/2802 [==============================] - 1099s 391ms/step - loss: 1.2507 - accuracy: 0.6076 - val_loss: 0.7590 - val_accuracy: 0.7446 - lr: 0.0010
Epoch 2/10
2802/2802 [==============================] - 1111s 397ms/step - loss: 0.8962 - accuracy: 0.7041 - val_loss: 0.6956 - val_accuracy: 0.7584 - lr: 0.0010
Epoch 3/10
2802/2802 [==============================] - 1128s 402ms/step - loss: 0.8191 - accuracy: 0.7258 - val_loss: 0.7252 - val_accuracy: 0.7527 - lr: 0.0010
Epoch 4/10
2802/2802 [==============================] - 1127s 402ms/step - loss: 0.7713 - accuracy: 0.7394 - val_loss: 0.6566 - val_accuracy: 0.7730 - lr: 0.0010
Epoch 5/10
2802/2802 [==============================] - 1170s 417ms/step - loss: 0.7457 - accuracy: 0.7463 - val_loss: 0.6151 - val_accuracy: 0.7874 - lr: 0.0010
Epoch 6/10
2802/2802 [==============================] - 1163s 415ms/step - loss: 0.7221 - accuracy: 0.7550 - val_loss: 0.6318 - val_accuracy: 0.7811 - lr: 0.0010
Epoch 7/10
2802/2802 [==============================] - 1165s 416ms/step - loss: 0.7025 - accuracy: 0.7595 - val_loss: 0.6020 - val_accuracy: 0.7884 - lr: 0.0010
Epoch 8/10
2802/2802 [==============================] - 1174s 419ms/step - loss: 0.6899 - accuracy: 0.7642 - val_loss: 0.6087 - val_accuracy: 0.7856 - lr: 0.0010
Epoch 9/10
2802/2802 [==============================] - 1200s 428ms/step - loss: 0.6746 - accuracy: 0.7687 - val_loss: 0.5829 - val_accuracy: 0.7939 - lr: 0.0010
Epoch 10/10
2802/2802 [==============================] - 1176s 420ms/step - loss: 0.6624 - accuracy: 0.7742 - val_loss: 0.5192 - val_accuracy: 0.8160 - lr: 0.0010
In [ ]:
history_df = pd.DataFrame(history.history)
history_df.insert(0, 'epoch', range(1, len(history_df) + 1))
history_df
Out[ ]:
epoch loss accuracy val_loss val_accuracy lr
0 1 1.250749 0.607628 0.759000 0.744600 0.001
1 2 0.896184 0.704090 0.695574 0.758393 0.001
2 3 0.819092 0.725798 0.725238 0.752720 0.001
3 4 0.771268 0.739363 0.656598 0.773018 0.001
4 5 0.745735 0.746347 0.615140 0.787384 0.001
5 6 0.722142 0.754981 0.631768 0.781138 0.001
6 7 0.702497 0.759521 0.601996 0.788372 0.001
7 8 0.689927 0.764229 0.608651 0.785614 0.001
8 9 0.674576 0.768669 0.582919 0.793890 0.001
9 10 0.662447 0.774246 0.519231 0.815958 0.001
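Since the `ModelCheckpoint` callback keeps the weights from the epoch with the best (lowest) validation loss, the history table can tell us which epoch that was. A minimal sketch, using a hypothetical three-epoch stand-in for the table above:

```python
import pandas as pd

# Hypothetical three-epoch history, shaped like the table above
history_df = pd.DataFrame({
    "epoch": [1, 2, 3],
    "val_loss": [0.759, 0.696, 0.725],
})

# idxmin() gives the row index of the lowest validation loss,
# i.e. the epoch whose weights the checkpoint would have saved
best_epoch = int(history_df.loc[history_df["val_loss"].idxmin(), "epoch"])
print(best_epoch)  # → 2
```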
In [ ]:
# Create a DataFrame from the history object
history_df = pd.DataFrame(history.history)

# Plot the training and validation loss
plt.figure(figsize=(9, 5))
epochs = range(1, len(history_df) + 1)
plt.plot(epochs, history_df['loss'], 'bo', label='Training loss')
plt.plot(epochs, history_df['val_loss'], 'ro', label='Validation loss')

plt.xlabel('Epochs')
plt.xticks(epochs)
plt.ylabel('Loss')
plt.legend()
plt.title('Training and validation loss')
plt.show()

# Plot the training and validation accuracy
plt.figure(figsize=(9, 5))
plt.plot(epochs, history_df['accuracy'], 'bo', label='Training accuracy')
plt.plot(epochs, history_df['val_accuracy'], 'ro', label='Validation accuracy')

plt.xlabel('Epochs')
plt.xticks(epochs)
plt.ylabel('Accuracy')
plt.legend()
plt.title('Training and validation accuracy')
plt.show()
[Plot: training and validation loss]
[Plot: training and validation accuracy]
In [ ]:
# evaluate the model
model.evaluate(test_generator)
601/601 [==============================] - 47s 78ms/step - loss: 0.5252 - accuracy: 0.8127
Out[ ]:
[0.5252286791801453, 0.8127112984657288]
In [ ]:
from tensorflow.keras.applications import InceptionV3

inception_model = InceptionV3(weights='imagenet', include_top=False, input_shape=(150, 150, 3))
In [ ]:
# Freeze the base model
inception_model.trainable = False

# Add custom layers on top of the base model
x = inception_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(512, activation='relu')(x)
predictions = Dense(43, activation='softmax')(x)

# Create the model
model = Model(inputs=inception_model.input, outputs=predictions)

# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

inception_callbacks = [
    callbacks.ModelCheckpoint(filepath='./models/new/disease-inception.h5', save_best_only=True),
    callbacks.EarlyStopping(monitor='val_loss', patience=10, verbose=1),
    callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, verbose=1, min_lr=1e-5)
]

# Training the model
history = model.fit(train_generator, validation_data=valid_generator, epochs=10, callbacks=inception_callbacks)
Epoch 1/10
2802/2802 [==============================] - 1126s 401ms/step - loss: 0.9927 - accuracy: 0.6901 - val_loss: 0.6263 - val_accuracy: 0.7841 - lr: 0.0010
Epoch 2/10
2802/2802 [==============================] - 1101s 393ms/step - loss: 0.7499 - accuracy: 0.7537 - val_loss: 0.5716 - val_accuracy: 0.8101 - lr: 0.0010
Epoch 3/10
2802/2802 [==============================] - 1103s 394ms/step - loss: 0.7017 - accuracy: 0.7674 - val_loss: 0.5925 - val_accuracy: 0.8036 - lr: 0.0010
Epoch 4/10
2802/2802 [==============================] - 1101s 393ms/step - loss: 0.6670 - accuracy: 0.7786 - val_loss: 0.5550 - val_accuracy: 0.8085 - lr: 0.0010
Epoch 5/10
2802/2802 [==============================] - 1099s 392ms/step - loss: 0.6453 - accuracy: 0.7860 - val_loss: 0.5064 - val_accuracy: 0.8293 - lr: 0.0010
Epoch 6/10
2802/2802 [==============================] - 1102s 393ms/step - loss: 0.6288 - accuracy: 0.7924 - val_loss: 0.4780 - val_accuracy: 0.8349 - lr: 0.0010
Epoch 7/10
2802/2802 [==============================] - 1109s 396ms/step - loss: 0.6174 - accuracy: 0.7958 - val_loss: 0.4502 - val_accuracy: 0.8485 - lr: 0.0010
Epoch 8/10
2802/2802 [==============================] - 1137s 406ms/step - loss: 0.6040 - accuracy: 0.7991 - val_loss: 0.4677 - val_accuracy: 0.8406 - lr: 0.0010
Epoch 9/10
2802/2802 [==============================] - 1105s 394ms/step - loss: 0.5975 - accuracy: 0.8019 - val_loss: 0.4297 - val_accuracy: 0.8533 - lr: 0.0010
Epoch 10/10
2802/2802 [==============================] - 1105s 394ms/step - loss: 0.5838 - accuracy: 0.8057 - val_loss: 0.4479 - val_accuracy: 0.8491 - lr: 0.0010
In [ ]:
history_df = pd.DataFrame(history.history)
history_df.insert(0, 'epoch', range(1, len(history_df) + 1))
history_df
Out[ ]:
epoch loss accuracy val_loss val_accuracy lr
0 1 0.992702 0.690067 0.626298 0.784052 0.001
1 2 0.749905 0.753698 0.571625 0.810129 0.001
2 3 0.701684 0.767386 0.592548 0.803571 0.001
3 4 0.667020 0.778564 0.555006 0.808463 0.001
4 5 0.645300 0.786004 0.506417 0.829282 0.001
5 6 0.628777 0.792352 0.478001 0.834903 0.001
6 7 0.617358 0.795810 0.450240 0.848488 0.001
7 8 0.604008 0.799067 0.467703 0.840629 0.001
8 9 0.597451 0.801934 0.429734 0.853328 0.001
9 10 0.583820 0.805705 0.447936 0.849061 0.001
In [ ]:
# Create a DataFrame from the history object
history_df = pd.DataFrame(history.history)

# Plot the training and validation loss
plt.figure(figsize=(9, 5))
epochs = range(1, len(history_df) + 1)
plt.plot(epochs, history_df['loss'], 'bo', label='Training loss')
plt.plot(epochs, history_df['val_loss'], 'ro', label='Validation loss')

plt.xlabel('Epochs')
plt.xticks(epochs)
plt.ylabel('Loss')
plt.legend()
plt.title('Training and validation loss')
plt.show()

# Plot the training and validation accuracy
plt.figure(figsize=(9, 5))
plt.plot(epochs, history_df['accuracy'], 'bo', label='Training accuracy')
plt.plot(epochs, history_df['val_accuracy'], 'ro', label='Validation accuracy')

plt.xlabel('Epochs')
plt.xticks(epochs)
plt.ylabel('Accuracy')
plt.legend()
plt.title('Training and validation accuracy')
plt.show()
[Plot: training and validation loss]
[Plot: training and validation accuracy]
In [ ]:
# evaluate the model
model.evaluate(test_generator)
601/601 [==============================] - 40s 66ms/step - loss: 0.4381 - accuracy: 0.8506
Out[ ]:
[0.43813350796699524, 0.8505747318267822]
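Pulling the test accuracies reported above into one table makes the comparison between the three models explicit. A minimal sketch (the numbers are the test-set figures from the cells above; the model names are our own labels):

```python
import pandas as pd

# Test accuracies reported earlier in this notebook
results = pd.DataFrame({
    "model": ["vanilla CNN", "VGG16 (frozen base)", "InceptionV3 (frozen base)"],
    "test_accuracy": [0.9296, 0.8127, 0.8506],
}).sort_values("test_accuracy", ascending=False)

print(results.iloc[0]["model"])  # → vanilla CNN
```

With only 10 epochs and fully frozen bases, both transfer-learning models still trail the vanilla CNN; fine-tuning or longer training would be the natural next step.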

Future Work¶

Going forward, there are a few things we can do to improve this machine learning pipeline.

  1. Improving the dataset. More training images would give the model more to learn from and should improve performance further.

  2. Balancing the dataset. Collecting more images for the classes where we currently have only a few would very likely improve model performance.

  3. Using even more advanced models for transfer learning. Beyond the VGG16 and InceptionV3 bases used above, we could try architectures such as VGG19, ResNet50, or EfficientNet, and fine-tune the top layers of the base model rather than keeping it fully frozen. But this calls for more GPU processing power.
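The imbalance noted in point 2 can also be mitigated without new data, by weighting rare classes more heavily during training. A minimal sketch of inverse-frequency class weights, using hypothetical per-class counts drawn from the report above (in practice the resulting dict could be passed to `model.fit` via its `class_weight` argument):

```python
# Hypothetical per-class image counts, taken from the support column above
counts = {"Grape Black rot": 1700, "Potato healthy": 55, "tomato canker": 9}

total = sum(counts.values())
n_classes = len(counts)

# Inverse-frequency weighting: rare classes get proportionally larger weights
class_weight = {c: total / (n_classes * n) for c, n in counts.items()}
print(round(class_weight["tomato canker"], 1))  # → 65.3
```

This is a cheap first step; it does not replace collecting more images, but it stops the loss from being dominated by the large classes.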